# Directory-scoped include paths. They are kept directory-scoped because the
# communication/ and hetu_cache/ subdirectories added at the end of this file
# inherit them (presumably they rely on that; verify before converting these
# to target_include_directories).
include_directories(header common memory_pool)
include_directories(SYSTEM ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

# Collect the sources of each source directory into a list variable.
# NOTE: aux_source_directory runs at configure time only; re-run CMake after
# adding or removing source files so the lists are refreshed.
aux_source_directory(common COMMON_SRC)
aux_source_directory(cuda_common CUDA_COMMON_SRC)
aux_source_directory(ops OPS_SRC)
aux_source_directory(dnnl_ops DNNL_OPS_SRC)

# The runtime library. It is created without sources; the backend blocks
# below attach sources depending on which backends are enabled.
add_library(c_runtime_api SHARED)

# Convenience target: building `hetu` builds c_runtime_api.
# add_dependencies() is the documented way to depend on another target;
# add_custom_target's DEPENDS option is meant for files / custom-command outputs.
add_custom_target(hetu)
add_dependencies(hetu c_runtime_api)

# GPU backend: CUDA sources linked against cuDNN, cuBLAS, cuSPARSE and cuRAND.
if(HETU_COMPILE_GPU)
    find_package(CUDNN 7.5 REQUIRED)
    find_package(CUB REQUIRED)
    find_package(THRUST REQUIRED)

    # Auto-dereferencing form: an empty/undefined version can no longer expand
    # into the malformed `if( VERSION_GREATER_EQUAL ...)`.
    if(CUDNN_VERSION VERSION_GREATER_EQUAL "8")
        add_definitions(-DCUDNN8)
    endif()
    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11")
        # if cuda 11, the function cusparseScsrmv is undefined
        list(REMOVE_ITEM OPS_SRC "ops/CuSparseCsrmv.cu")
    endif()
    # Kept directory-scoped: subdirectories added later in this file inherit
    # it (presumably required by them; verify before making it target-scoped).
    add_definitions(-DDEVICE_GPU)

    # PRIVATE: these are this library's own translation units. PUBLIC would
    # also inject them into every target that links c_runtime_api.
    target_sources(c_runtime_api PRIVATE ${COMMON_SRC} ${CUDA_COMMON_SRC} ${OPS_SRC})
    # Plain (keyword-less) signature kept on purpose: CMake forbids mixing it
    # with the PRIVATE/PUBLIC form on the same target, and other calls use it.
    target_link_libraries(c_runtime_api cudart cublas cusparse curand)
    target_link_libraries(c_runtime_api ${CUDNN_LIBRARY_PATH})
    target_include_directories(c_runtime_api PUBLIC ${CUDNN_INCLUDE_PATH})
    target_include_directories(c_runtime_api PUBLIC cuda_common)
    target_include_directories(c_runtime_api PRIVATE ${CUB_INCLUDE_DIR} ${THRUST_INCLUDE_DIR})

    # Only request architectures the detected nvcc can target:
    # sm_86 is available from CUDA 11.1, so CUDA 11.0 must stop at sm_80.
    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.1")
        set(HETU_CUDA_ARCH 52 60 61 70 75 80 86)
    elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11")
        set(HETU_CUDA_ARCH 52 60 61 70 75 80)
    else()
        set(HETU_CUDA_ARCH 50 52 60 61 70 75)
    endif()
    set_property(TARGET c_runtime_api PROPERTY CUDA_ARCHITECTURES ${HETU_CUDA_ARCH})
endif()
# CPU backend: DNNL (mkl-dnn / oneDNN) operators.
if(HETU_COMPILE_MKL)
    # MKL_FOUND / DNNL_* are presumably set by the project's FindMKL module
    # (not visible here).
    find_package(MKL)
    if(NOT MKL_FOUND)
        # No usable installation found: download mkl-dnn v1.6.1 and build it
        # as part of this project. FetchContent must be included before its
        # commands are used; including it again is harmless if a parent
        # directory already did.
        include(FetchContent)
        FetchContent_Declare(mkl URL https://github.com/intel/mkl-dnn/archive/v1.6.1.tar.gz)
        message(STATUS "Preparing mkl-dnn ...")
        FetchContent_MakeAvailable(mkl)
        message(STATUS "Finish mkl-dnn.")
        # `dnnl` is presumably the library target defined by the fetched project.
        target_link_libraries(c_runtime_api dnnl)
    else()
        # Use the installation located by find_package(MKL).
        target_include_directories(c_runtime_api PUBLIC ${DNNL_INCLUDE_DIR})
        target_include_directories(c_runtime_api PUBLIC ${DNNL_BUILD_INCLUDE_DIR})
        target_link_libraries(c_runtime_api ${DNNL_LIBRARY})
    endif()
    # PRIVATE: the library's own sources must not propagate to its consumers.
    target_sources(c_runtime_api PRIVATE ${COMMON_SRC} ${DNNL_OPS_SRC})
endif()

# Optional components, each in its own subdirectory. Because they are added
# after the directory-scoped include_directories/add_definitions calls earlier
# in this file, they inherit those settings.
# `if(VAR)` tests the option directly instead of double-dereferencing `${VAR}`.
if(HETU_ALLREDUCE)
    add_subdirectory(communication)
endif()

if(HETU_CACHE)
    add_subdirectory(hetu_cache)
endif()
